{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/lamb_optimizer.py:34: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from albert import modeling\n",
    "from albert import optimization\n",
    "from albert import tokenization\n",
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:loading sentence piece model\n"
     ]
    }
   ],
   "source": [
    "# Cased SentencePiece tokenizer; vocab and model files are read from the\n",
    "# same directory as the pretrained checkpoint.\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "    vocab_file = 'albert-base-2020-04-10/sp10m.cased.v10.vocab',\n",
    "    do_lower_case = False,\n",
    "    spm_model_file = 'albert-base-2020-04-10/sp10m.cased.v10.model',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<albert.modeling.AlbertConfig at 0x7f7a3ec15400>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_config = modeling.AlbertConfig.from_json_file('albert-base-2020-04-10/config.json')\n",
    "bert_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code stored in the file;\n",
    "# only load 'albert-squad-train.pkl' when it comes from a trusted pipeline.\n",
    "with open('albert-squad-train.pkl', 'rb') as pickled_squad:\n",
    "    train_features, train_examples = pickle.load(pickled_squad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_seq_length = 384\n",
    "doc_stride = 128\n",
    "max_query_length = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 5\n",
    "batch_size = 22\n",
    "warmup_proportion = 0.1\n",
    "n_best_size = 20\n",
    "# The training loop iterates range(0, len(train_features), batch_size), so a\n",
    "# trailing partial batch is also an optimizer step. Ceiling division keeps the\n",
    "# learning-rate schedule in step with the number of updates actually executed.\n",
    "steps_per_epoch = -(-len(train_features) // batch_size)\n",
    "num_train_steps = steps_per_epoch * epoch\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.contrib import layers as contrib_layers\n",
    "\n",
    "class Model:\n",
    "    \"\"\"Input placeholders plus the ALBERT encoder used for SQuAD fine-tuning.\n",
    "\n",
    "    After construction, `output` holds the encoder's sequence output; the\n",
    "    task-specific heads are attached to it outside this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, is_training = True):\n",
    "        per_token = [None, None]\n",
    "        per_example = [None]\n",
    "\n",
    "        # encoder inputs\n",
    "        self.X = tf.placeholder(tf.int32, per_token)\n",
    "        self.segment_ids = tf.placeholder(tf.int32, per_token)\n",
    "        self.input_masks = tf.placeholder(tf.int32, per_token)\n",
    "        # supervision for the span and answerability heads\n",
    "        self.start_positions = tf.placeholder(tf.int32, per_example)\n",
    "        self.end_positions = tf.placeholder(tf.int32, per_example)\n",
    "        self.p_mask = tf.placeholder(tf.int32, per_token)\n",
    "        self.is_impossible = tf.placeholder(tf.int32, per_example)\n",
    "\n",
    "        # bert_config is the notebook-level AlbertConfig loaded earlier\n",
    "        encoder = modeling.AlbertModel(\n",
    "            config = bert_config,\n",
    "            is_training = is_training,\n",
    "            input_ids = self.X,\n",
    "            input_mask = self.input_masks,\n",
    "            token_type_ids = self.segment_ids,\n",
    "            use_one_hot_embeddings = False,\n",
    "        )\n",
    "        self.output = encoder.get_sequence_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 2e-5\n",
    "is_training = True\n",
    "\n",
    "tf.reset_default_graph()\n",
    "model = Model(is_training = is_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-16-857e07a2a191>:157: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n"
     ]
    }
   ],
   "source": [
    "output = model.output\n",
    "bsz = tf.shape(output)[0]\n",
    "return_dict = {}\n",
    "output = tf.transpose(output, [1, 0, 2])\n",
    "\n",
    "# invalid position mask such as query and special symbols (PAD, SEP, CLS)\n",
    "p_mask = tf.cast(model.p_mask, dtype = tf.float32)\n",
    "\n",
    "# start-position head: one score per token\n",
    "with tf.variable_scope('start_logits'):\n",
    "    start_kernel_init = modeling.create_initializer(\n",
    "        bert_config.initializer_range\n",
    "    )\n",
    "    start_scores = tf.layers.dense(\n",
    "        output, 1, kernel_initializer = start_kernel_init\n",
    "    )\n",
    "    # drop the trailing unit axis, then swap the two remaining axes\n",
    "    start_logits = tf.transpose(tf.squeeze(start_scores, -1), [1, 0])\n",
    "    # positions flagged in p_mask are pushed to -1e30 so the softmax ignores them\n",
    "    keep_mask = 1 - p_mask\n",
    "    start_logits_masked = start_logits * keep_mask - 1e30 * p_mask\n",
    "    start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)\n",
    "\n",
    "# Beam widths for inference-time decoding. They are only read when\n",
    "# is_training is False; the notebook never defined them, so that branch\n",
    "# raised NameError.\n",
    "# NOTE(review): 5/5 is believed to match the reference SQuAD v2 scripts - confirm.\n",
    "start_n_top = 5\n",
    "end_n_top = 5\n",
    "\n",
    "# logit of the end position\n",
    "with tf.variable_scope('end_logits'):\n",
    "    if is_training:\n",
    "        # during training, compute the end logits based on the\n",
    "        # ground truth of the start position\n",
    "        start_positions = tf.reshape(model.start_positions, [-1])\n",
    "        start_index = tf.one_hot(\n",
    "            start_positions,\n",
    "            depth = max_seq_length,\n",
    "            axis = -1,\n",
    "            dtype = tf.float32,\n",
    "        )\n",
    "        # select the hidden state at the labelled start token of each example\n",
    "        start_features = tf.einsum('lbh,bl->bh', output, start_index)\n",
    "        # repeat it along the leading axis so it can be concatenated with\n",
    "        # `output`; this requires inputs padded to max_seq_length\n",
    "        start_features = tf.tile(\n",
    "            start_features[None], [max_seq_length, 1, 1]\n",
    "        )\n",
    "        end_logits = tf.layers.dense(\n",
    "            tf.concat([output, start_features], axis = -1),\n",
    "            bert_config.hidden_size,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            activation = tf.tanh,\n",
    "            name = 'dense_0',\n",
    "        )\n",
    "        end_logits = contrib_layers.layer_norm(\n",
    "            end_logits, begin_norm_axis = -1\n",
    "        )\n",
    "\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_logits,\n",
    "            1,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            name = 'dense_1',\n",
    "        )\n",
    "        end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])\n",
    "        # masked positions get -1e30 so they vanish under the softmax\n",
    "        end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask\n",
    "        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)\n",
    "    else:\n",
    "        # during inference, compute the end logits based on beam search\n",
    "\n",
    "        start_top_log_probs, start_top_index = tf.nn.top_k(\n",
    "            start_log_probs, k = start_n_top\n",
    "        )\n",
    "        start_index = tf.one_hot(\n",
    "            start_top_index,\n",
    "            depth = max_seq_length,\n",
    "            axis = -1,\n",
    "            dtype = tf.float32,\n",
    "        )\n",
    "        # one start feature per (example, start candidate)\n",
    "        start_features = tf.einsum('lbh,bkl->bkh', output, start_index)\n",
    "        end_input = tf.tile(output[:, :, None], [1, 1, start_n_top, 1])\n",
    "        start_features = tf.tile(\n",
    "            start_features[None], [max_seq_length, 1, 1, 1]\n",
    "        )\n",
    "        end_input = tf.concat([end_input, start_features], axis = -1)\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_input,\n",
    "            bert_config.hidden_size,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            activation = tf.tanh,\n",
    "            name = 'dense_0',\n",
    "        )\n",
    "        end_logits = contrib_layers.layer_norm(\n",
    "            end_logits, begin_norm_axis = -1\n",
    "        )\n",
    "        end_logits = tf.layers.dense(\n",
    "            end_logits,\n",
    "            1,\n",
    "            kernel_initializer = modeling.create_initializer(\n",
    "                bert_config.initializer_range\n",
    "            ),\n",
    "            name = 'dense_1',\n",
    "        )\n",
    "        end_logits = tf.reshape(\n",
    "            end_logits, [max_seq_length, -1, start_n_top]\n",
    "        )\n",
    "        end_logits = tf.transpose(end_logits, [1, 2, 0])\n",
    "        # p_mask gains an axis so it broadcasts over the start candidates\n",
    "        end_logits_masked = (\n",
    "            end_logits * (1 - p_mask[:, None]) - 1e30 * p_mask[:, None]\n",
    "        )\n",
    "        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)\n",
    "        end_top_log_probs, end_top_index = tf.nn.top_k(\n",
    "            end_log_probs, k = end_n_top\n",
    "        )\n",
    "        # flatten the (start candidate, end candidate) grid per example\n",
    "        end_top_log_probs = tf.reshape(\n",
    "            end_top_log_probs, [-1, start_n_top * end_n_top]\n",
    "        )\n",
    "        end_top_index = tf.reshape(\n",
    "            end_top_index, [-1, start_n_top * end_n_top]\n",
    "        )\n",
    "        \n",
    "# expose only the tensors that exist for the current mode\n",
    "if is_training:\n",
    "    return_dict.update(\n",
    "        start_log_probs = start_log_probs,\n",
    "        end_log_probs = end_log_probs,\n",
    "    )\n",
    "else:\n",
    "    return_dict.update(\n",
    "        start_top_log_probs = start_top_log_probs,\n",
    "        start_top_index = start_top_index,\n",
    "        end_top_log_probs = end_top_log_probs,\n",
    "        end_top_index = end_top_index,\n",
    "    )\n",
    "\n",
    "# answerability head: a single logit per example\n",
    "with tf.variable_scope('answer_class'):\n",
    "    answer_kernel_init = modeling.create_initializer(\n",
    "        bert_config.initializer_range\n",
    "    )\n",
    "\n",
    "    # hidden state at position 0 (the CLS token), picked with a one-hot selector\n",
    "    first_position = tf.zeros([bsz], dtype = tf.int32)\n",
    "    cls_index = tf.one_hot(\n",
    "        first_position, max_seq_length, axis = -1, dtype = tf.float32\n",
    "    )\n",
    "    cls_feature = tf.einsum('lbh,bl->bh', output, cls_index)\n",
    "\n",
    "    # hidden states averaged under the predicted start distribution\n",
    "    start_p = tf.nn.softmax(\n",
    "        start_logits_masked, axis = -1, name = 'softmax_start'\n",
    "    )\n",
    "    start_feature = tf.einsum('lbh,bl->bh', output, start_p)\n",
    "\n",
    "    # note(zhiliny): no dependency on end_feature so that we can obtain\n",
    "    # one single `cls_logits` for each sample\n",
    "    ans_feature = tf.concat([start_feature, cls_feature], -1)\n",
    "    ans_hidden = tf.layers.dense(\n",
    "        ans_feature,\n",
    "        units = bert_config.hidden_size,\n",
    "        activation = tf.tanh,\n",
    "        kernel_initializer = answer_kernel_init,\n",
    "        name = 'dense_0',\n",
    "    )\n",
    "    ans_hidden = tf.layers.dropout(\n",
    "        ans_hidden,\n",
    "        rate = bert_config.hidden_dropout_prob,\n",
    "        training = is_training,\n",
    "    )\n",
    "    cls_scores = tf.layers.dense(\n",
    "        ans_hidden,\n",
    "        units = 1,\n",
    "        use_bias = False,\n",
    "        kernel_initializer = answer_kernel_init,\n",
    "        name = 'dense_1',\n",
    "    )\n",
    "    # drop the trailing unit axis\n",
    "    cls_logits = tf.squeeze(cls_scores, -1)\n",
    "    \n",
    "return_dict['cls_logits'] = cls_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "seq_length = tf.shape(model.X)[1]\n",
    "\n",
    "def compute_loss(log_probs, positions):\n",
    "    \"\"\"Return the mean negative log-likelihood of the labelled positions.\n",
    "\n",
    "    `positions` is one-hot encoded over `seq_length` (the dynamic width of\n",
    "    model.X) and used to pick the matching entry of `log_probs`.\n",
    "    \"\"\"\n",
    "    target = tf.one_hot(positions, depth = seq_length, dtype = tf.float32)\n",
    "    # the one-hot product keeps only the log-prob at the labelled index\n",
    "    per_example_nll = -tf.reduce_sum(target * log_probs, axis = -1)\n",
    "    return tf.reduce_mean(per_example_nll)\n",
    "\n",
    "# span loss: average of the start and end negative log-likelihoods\n",
    "start_loss = compute_loss(return_dict['start_log_probs'], model.start_positions)\n",
    "end_loss = compute_loss(return_dict['end_log_probs'], model.end_positions)\n",
    "span_loss = (start_loss + end_loss) * 0.5\n",
    "\n",
    "# answerability loss: sigmoid cross-entropy of cls_logits against is_impossible\n",
    "cls_logits = return_dict['cls_logits']\n",
    "is_impossible = tf.reshape(model.is_impossible, [-1])\n",
    "impossible_labels = tf.cast(is_impossible, dtype = tf.float32)\n",
    "regression_loss = tf.reduce_mean(\n",
    "    tf.nn.sigmoid_cross_entropy_with_logits(\n",
    "        labels = impossible_labels, logits = cls_logits\n",
    "    )\n",
    ")\n",
    "\n",
    "# note(zhiliny): by default multiply the loss by 0.5 so that the scale is\n",
    "# comparable to start_loss and end_loss\n",
    "total_loss = span_loss + regression_loss * 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:36: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:41: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "INFO:tensorflow:++++++ warmup starts at step 0, for 3011 steps ++++++\n",
      "INFO:tensorflow:using adamw\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:101: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "optimizer = optimization.create_optimizer(total_loss, learning_rate, \n",
    "                                          num_train_steps, num_warmup_steps, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from albert-base-2020-04-10/model.ckpt-400000\n"
     ]
    }
   ],
   "source": [
    "# Initialise every variable first, then overwrite the trainable variables under\n",
    "# scope 'bert' with the pretrained checkpoint. Variables outside that scope\n",
    "# keep their fresh initial values.\n",
    "sess = tf.InteractiveSession()\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)\n",
    "\n",
    "var_lists = tf.get_collection(\n",
    "    tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert'\n",
    ")\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, 'albert-base-2020-04-10/model.ckpt-400000')\n",
    "\n",
    "# Alternative (disabled): restore every trainable variable from another checkpoint.\n",
    "# saver = tf.train.Saver(var_list = tf.trainable_variables())\n",
    "# saver.restore(sess, 'bert-base-squad/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6024/6024 [1:03:01<00:00,  1.59it/s, cost=0.592, end_loss=0.00062, regression_loss=0.653, start_loss=0.53]  \n",
      "train minibatch loop:   0%|          | 0/6024 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n",
      "2.027449\n",
      "2.3524973\n",
      "1.1032976\n",
      "0.5991028\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6024/6024 [1:03:07<00:00,  1.59it/s, cost=0.624, end_loss=0.000278, regression_loss=0.617, start_loss=0.63] \n",
      "train minibatch loop:   0%|          | 0/6024 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1\n",
      "1.3526311\n",
      "1.6595553\n",
      "0.5204538\n",
      "0.5252531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6024/6024 [1:02:56<00:00,  1.60it/s, cost=0.412, end_loss=7.29e-5, regression_loss=0.434, start_loss=0.391] \n",
      "train minibatch loop:   0%|          | 0/6024 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2\n",
      "1.1491865\n",
      "1.4022663\n",
      "0.43187308\n",
      "0.4642336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6024/6024 [1:02:49<00:00,  1.60it/s, cost=0.289, end_loss=4.85e-5, regression_loss=0.318, start_loss=0.26]  \n",
      "train minibatch loop:   0%|          | 0/6024 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3\n",
      "0.98238\n",
      "1.1978309\n",
      "0.36663505\n",
      "0.40029392\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6024/6024 [1:03:19<00:00,  1.59it/s, cost=0.227, end_loss=7.65e-5, regression_loss=0.228, start_loss=0.226]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4\n",
      "0.86556995\n",
      "1.0575606\n",
      "0.32583946\n",
      "0.3477398\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# One pass over train_features per epoch, in stored order, one optimizer step\n",
    "# per minibatch. Per-epoch means of each loss are printed after the pass.\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_features), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    costs, start_losses, end_losses, regression_losses = [], [], [], []\n",
    "    for i in pbar:\n",
    "        batch = train_features[i: i + batch_size]\n",
    "        batch_ids = [b.input_ids for b in batch]\n",
    "        batch_masks = [b.input_mask for b in batch]\n",
    "        batch_segment = [b.segment_ids for b in batch]\n",
    "        batch_start = [b.start_position for b in batch]\n",
    "        batch_end = [b.end_position for b in batch]\n",
    "        # `batch_` prefix on purpose: plain `is_impossible` / `p_mask` would\n",
    "        # overwrite the graph tensors of the same name built in earlier cells.\n",
    "        batch_is_impossible = [b.is_impossible for b in batch]\n",
    "        batch_p_mask = [b.p_mask for b in batch]\n",
    "        cost, start_loss_, end_loss_, regression_loss_, _ = sess.run(\n",
    "            [total_loss, start_loss, end_loss, regression_loss, optimizer],\n",
    "            feed_dict = {\n",
    "                model.start_positions: batch_start,\n",
    "                model.end_positions: batch_end,\n",
    "                model.X: batch_ids,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks,\n",
    "                model.is_impossible: batch_is_impossible,\n",
    "                model.p_mask: batch_p_mask,\n",
    "            },\n",
    "        )\n",
    "        pbar.set_postfix(cost = cost, start_loss = start_loss_,\n",
    "                        end_loss = end_loss_, regression_loss = regression_loss_)\n",
    "        costs.append(cost)\n",
    "        start_losses.append(start_loss_)\n",
    "        end_losses.append(end_loss_)\n",
    "        regression_losses.append(regression_loss_)\n",
    "\n",
    "    # order of the printed means: total, start, end, answerability\n",
    "    print(f'epoch: {e}')\n",
    "    print(np.mean(costs))\n",
    "    print(np.mean(start_losses))\n",
    "    print(np.mean(end_losses))\n",
    "    print(np.mean(regression_losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'albert-base-squad/model.ckpt'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'albert-base-squad/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
